import math
import copy
import gym
import random
import numpy as np
import statistics
import pickle

# Import your updated custom/stochastic envs
import Continuous_CartPole      # e.g. same ID "Continuous-CartPole-v0" but now stochastic
import Continuous_Pendulum      # "Pendulum-v1" but optional noise
import continuous_mountain_car  # "MountainCarContinuous-v0"
import continuous_acrobot       # "ContinuousAcrobot-v0"
import improved_hopper



from SnapshotENV import SnapshotEnv
from poly_hoo_module import POLY_HOO

# 1) Environment IDs to benchmark (all of the improved / stochastic environments).
env_names = [
    "Continuous-CartPole-v0",
    "StochasticPendulum-v0",
    "StochasticMountainCarContinuous-v0",
    "StochasticContinuousAcrobot-v0",
    "ImprovedHopper-v0",
]

# 2) Constructor kwargs (noise scales) forwarded to gym.make for each environment.
#    Trailing "was ..." comments appear to record earlier tuning values.
ENV_NOISE_CONFIG = {
    "Continuous-CartPole-v0": dict(
        action_noise_scale=0.05,   # was 0.05
        dynamics_noise_scale=0.5,  # was 0.01
        obs_noise_scale=0.0,
    ),
    "StochasticPendulum-v0": dict(
        action_noise_scale=0.02,   # was 0.02
        dynamics_noise_scale=0.1,  # was 0.01
        obs_noise_scale=0.01,
        # other constructor kwargs (e.g. a different gravity "g") could be
        # added here, presumably if the env accepts them — confirm first
    ),
    "StochasticMountainCarContinuous-v0": dict(
        action_noise_scale=0.05,   # was 0.03
        dynamics_noise_scale=0.5,  # was 0.01
        obs_noise_scale=0.0,
    ),
    "StochasticContinuousAcrobot-v0": dict(
        action_noise_scale=0.05,   # was 0.05
        dynamics_noise_scale=0.7,  # was 0.01
        obs_noise_scale=0.01,
    ),
    "ImprovedHopper-v0": dict(
        action_noise_scale=0.03,
        dynamics_noise_scale=0.02,
        obs_noise_scale=0.01,
    ),
}

# 3) Global experiment configuration.
num_seeds = 20          # independent seeds per (environment, planning budget)
TEST_ITERATIONS = 150   # maximum real environment steps per evaluation episode
discount = 0.99         # discount applied to rewards in the evaluation loop
MAX_MCTS_DEPTH = 100    # per-depth HOOT schedules have MAX_MCTS_DEPTH + 1 entries
HOO_LIMIT_DEPTH = 10    # lim_depth handed to every POLY_HOO instance
power = 1.0             # POLY_HOO `p` parameter

INF = 1e9

# Dimension-adaptive parameters for HOOT
def get_hoot_params(dim):
    """Return the per-depth HOOT schedules ``(alphas, xis, etas)`` for ``dim``.

    Each schedule is a list of ``MAX_MCTS_DEPTH + 1`` identical values (one
    entry per tree depth). The constants shrink as the action dimension grows:
    dim == 1, dim <= 3, dim <= 8, and everything larger get progressively
    smaller alpha / xi / eta.
    """
    if dim == 1:
        alpha, xi, eta = 5, 20.0, 0.5      # original 1-D setting
    elif dim <= 3:
        alpha, xi, eta = 3, 15.0, 0.4      # low-dimensional (2D, 3D)
    elif dim <= 8:
        alpha, xi, eta = 2, 10.0, 0.3      # medium-dimensional (4D-8D)
    else:
        alpha, xi, eta = 1.5, 8.0, 0.25    # high-dimensional (9D+)

    schedule_len = MAX_MCTS_DEPTH + 1
    return [alpha] * schedule_len, [xi] * schedule_len, [eta] * schedule_len

# Planning budgets (ITERATIONS) form a geometric progression: 16 values
# int(3 * r**k) with r = 1000 ** (1/15), running from 3 up to about 3000.
# Only the six smallest budgets are actually run.
base = 1000 ** (1.0 / 15.0)
samples = list(map(lambda k: int(3 * (base ** k)), range(16)))
samples_to_use = samples[:6]

# 4) Node class with dimension-adaptive parameters
class Node:
    def __init__(self, snapshot, obs, is_done, parent, depth, dim,
                 min_action, max_action, alpha_arr, xi_arr, eta_arr, p):
        self.parent = parent
        self.snapshot = snapshot
        self.obs = obs
        self.is_done = is_done
        self.children = {}
        self.immediate_reward = 0
        self.dim = dim
        self.depth = depth

        safe_depth = min(depth, MAX_MCTS_DEPTH)

        # Dimension-adaptive rho and nu
        if dim == 1:
            rho = 2 ** (-2 / dim)
            nu = 4 * dim
        elif dim <= 3:
            rho = 2 ** (-1.5 / dim)
            nu = 3 * dim
        elif dim <= 8:
            rho = 2 ** (-1 / dim)
            nu = 2 * dim
        else:
            rho = 2 ** (-0.8 / dim)
            nu = 1.5 * dim

        self.hoo = POLY_HOO(
            dim=dim, nu=nu, rho=rho,
            min_value=min_action,
            max_value=max_action,
            lim_depth=HOO_LIMIT_DEPTH,
            alpha=alpha_arr[safe_depth],
            xi=xi_arr[safe_depth],
            eta=eta_arr[safe_depth],
            p=p
        )

    def selection(self, env, depth, max_depth):
        if self.is_done or depth >= max_depth:
            return 0
        raw_action = self.hoo.select_action().tolist()
        action = tuple(float(x) for x in raw_action)

        if action in self.children:
            child = self.children[action]
            immediate_reward = child.immediate_reward
            value = child.selection(env, depth + 1, max_depth)
            self.hoo.update(value + immediate_reward)
            return immediate_reward + value
        else:
            snapshot, obs, immediate_reward, is_done, _ = env.get_result(self.snapshot, action)

            # Get dimension-adaptive parameters for child
            child_alphas, child_xis, child_etas = get_hoot_params(self.dim)

            child = Node(
                snapshot=snapshot,
                obs=obs,
                is_done=is_done,
                parent=self,
                depth=depth + 1,
                dim=self.dim,
                min_action=self.hoo.Tree.root.cell[0][0],
                max_action=self.hoo.Tree.root.cell[0][1],
                alpha_arr=child_alphas,
                xi_arr=child_xis,
                eta_arr=child_etas,
                p=power
            )
            child.immediate_reward = immediate_reward
            self.children[action] = child
            value = child.selection(env, depth + 1, max_depth)
            self.hoo.update(value + immediate_reward)
            return immediate_reward + value

    def delete(self, node):
        for act in node.children:
            node.delete(node.children[act])
        del node

# 5) Main script
if __name__ == "__main__":
    results_filename = "poly_hoot_results.txt"

    # The context manager closes (and flushes) the results file even if an
    # environment or the planner raises part-way through the sweep.
    with open(results_filename, "a") as f_out:
        for envname in env_names:
            # Create environment with noise settings
            stoch_kwargs = ENV_NOISE_CONFIG.get(envname, {})
            base_env = gym.make(envname, **stoch_kwargs).env

            # Action bounds, action dimension and planning horizon.
            # Default: 1-D action in [-1, 1] with a horizon of 50 (this covers
            # MountainCar, Acrobot and any unlisted environment).
            min_action, max_action, dim, max_depth = -1.0, 1.0, 1, 50
            if envname == "Continuous-CartPole-v0":
                min_action = base_env.min_action
                max_action = base_env.max_action
            elif envname == "StochasticPendulum-v0":
                min_action, max_action = -2.0, 2.0
            elif envname == "ImprovedHopper-v0":
                dim, max_depth = 3, 100

            # base_env is only needed to read the bounds above.
            base_env.close()

            print(f"\nEnvironment: {envname}")
            print(f"Action dimension: {dim}")
            print(f"Max depth: {max_depth}")

            # Get dimension-adaptive parameters
            alphas, xis, etas = get_hoot_params(dim)
            print(f"HOOT parameters - alpha: {alphas[0]}, xi: {xis[0]}, eta: {etas[0]}")

            # Wrap in SnapshotEnv
            planning_env = SnapshotEnv(gym.make(envname, **stoch_kwargs).env)
            root_obs_ori = planning_env.reset()
            root_snapshot_ori = planning_env.get_snapshot()

            for ITERATIONS in samples_to_use:
                # Re-planning budget after each real step. For dim > 3 it is
                # halved with a floor of 100, but capped at ITERATIONS so that
                # re-planning never costs more than the initial planning budget.
                if dim <= 3:
                    plan_iterations = ITERATIONS
                else:
                    plan_iterations = min(ITERATIONS, max(ITERATIONS // 2, 100))

                seed_returns = []

                for seed_i in range(num_seeds):
                    random.seed(seed_i)
                    np.random.seed(seed_i)

                    # copy original snapshot
                    root_obs = copy.copy(root_obs_ori)
                    root_snapshot = copy.copy(root_snapshot_ori)

                    # Build tree
                    root = Node(
                        snapshot=root_snapshot,
                        obs=root_obs,
                        is_done=False,
                        parent=None,
                        depth=0,
                        dim=dim,
                        min_action=min_action,
                        max_action=max_action,
                        alpha_arr=alphas,
                        xi_arr=xis,
                        eta_arr=etas,
                        p=power
                    )

                    # plan
                    for _ in range(ITERATIONS):
                        root.selection(planning_env, depth=0, max_depth=max_depth)

                    # Evaluate on a fresh copy of the initial state. The snapshot
                    # was produced locally by planning_env.get_snapshot(), so
                    # unpickling it here is not handling untrusted data.
                    test_env = pickle.loads(root_snapshot)
                    total_reward = 0.0
                    current_discount = 1.0
                    done = False

                    for _step in range(TEST_ITERATIONS):
                        best_action = tuple(float(x) for x in root.hoo.get_point())

                        _, r, done, _ = test_env.step(best_action)
                        total_reward += r * current_discount
                        current_discount *= discount

                        if done:
                            test_env.close()
                            break

                        # Release every subtree except the one for the executed action.
                        for act_key in list(root.children):
                            if act_key != best_action:
                                root.delete(root.children.pop(act_key))

                        next_root = root.children.get(best_action)
                        if next_root is None:
                            # The executed action was never expanded during
                            # planning, so there is no subtree to continue from.
                            # The episode is cut short and the partial return is
                            # recorded.
                            break

                        # Drop the back-reference so the previous root (and its
                        # POLY_HOO) can be garbage-collected.
                        next_root.parent = None
                        next_root.depth = 0
                        root = next_root
                        for _ in range(plan_iterations):
                            root.selection(planning_env, depth=0, max_depth=max_depth)

                    if not done:
                        test_env.close()

                    seed_returns.append(total_reward)

                mean_return = statistics.mean(seed_returns)
                std_return = statistics.pstdev(seed_returns)
                # Reported spread is two population standard deviations across
                # seeds (not a standard error of the mean).
                interval = 2.0 * std_return

                msg = (f"Env={envname}, ITER={ITERATIONS}: "
                       f"Mean={mean_return:.3f} ± {interval:.3f} "
                       f"(over {num_seeds} seeds)")
                print(msg)
                f_out.write(msg + "\n")
                f_out.flush()

    print(f"Done! Results saved to {results_filename}")
